package cloud.seri.gateway.web.filter

import cloud.seri.gateway.security.oauth2.OAuth2AuthenticationService
import cloud.seri.gateway.security.oauth2.OAuth2CookieHelper
import org.slf4j.LoggerFactory
import org.springframework.security.oauth2.common.exceptions.ClientAuthenticationException
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException
import org.springframework.security.oauth2.common.exceptions.UnauthorizedClientException
import org.springframework.security.oauth2.provider.token.TokenStore
import org.springframework.web.client.HttpClientErrorException
import org.springframework.web.filter.GenericFilterBean

import javax.servlet.FilterChain
import javax.servlet.ServletException
import javax.servlet.ServletRequest
import javax.servlet.ServletResponse
import javax.servlet.http.Cookie
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
import java.io.IOException

/**
 * Filters incoming requests and refreshes the access token before it expires.
 *
 * The filter only decides whether a refresh is needed; the refresh itself, and the stripping of
 * unusable token cookies, is delegated to [OAuth2AuthenticationService].
 */
class RefreshTokenFilter(
    /**
     * The [OAuth2AuthenticationService] is doing the actual work. We are just a simple filter after all.
     */
    private val authenticationService: OAuth2AuthenticationService,
    /**
     * Used to read the access token behind the cookie value so its expiry can be inspected.
     */
    private val tokenStore: TokenStore
) : GenericFilterBean() {

    private val log = LoggerFactory.getLogger(RefreshTokenFilter::class.java)

    /**
     * Check the access token cookie and refresh it, if it is either not present, expired or about to expire.
     *
     * If refreshing fails with a [ClientAuthenticationException], the failure is logged and the request
     * returned by [OAuth2AuthenticationService.stripTokens] is passed downstream instead of the original one.
     *
     * @param servletRequest the incoming request; it is cast to [HttpServletRequest].
     * @param servletResponse the outgoing response; it is cast to [HttpServletResponse].
     * @param filterChain the remaining filter chain, always invoked exactly once.
     */
    @Throws(IOException::class, ServletException::class)
    override fun doFilter(servletRequest: ServletRequest, servletResponse: ServletResponse, filterChain: FilterChain) {
        val httpServletRequest = servletRequest as HttpServletRequest
        val httpServletResponse = servletResponse as HttpServletResponse
        val downstreamRequest: HttpServletRequest = try {
            refreshTokensIfExpiring(httpServletRequest, httpServletResponse)
        } catch (ex: ClientAuthenticationException) {
            log.warn("Security exception: could not refresh tokens", ex)
            authenticationService.stripTokens(httpServletRequest)
        }

        filterChain.doFilter(downstreamRequest, servletResponse)
    }

    /**
     * Refresh the access and refresh tokens if they are about to expire.
     *
     * @param httpServletRequest the servlet request holding the current cookies. If no refresh cookie is present,
     * then we are out of luck.
     * @param httpServletResponse the servlet response that gets the new set-cookie headers, if they had to be
     * refreshed.
     * @return a new request to use downstream that contains the new cookies, if they had to be refreshed;
     * otherwise the request that was passed in.
     * @throws UnauthorizedClientException if the refresh call failed with an [HttpClientErrorException].
     * @throws InvalidTokenException if there is no refresh token and the access token is expired or
     * cannot be found in the token store.
     */
    fun refreshTokensIfExpiring(httpServletRequest: HttpServletRequest, httpServletResponse: HttpServletResponse): HttpServletRequest {
        //get access token from cookie
        val accessTokenCookie = OAuth2CookieHelper.getAccessTokenCookie(httpServletRequest)
        if (!mustRefreshToken(accessTokenCookie)) {
            //access token is present and not within the refresh window: nothing to do
            return httpServletRequest
        }

        //we either have no usable access token, or it is expired, or it is about to expire:
        //get the refresh token cookie and, if present, request new tokens
        val refreshCookie = OAuth2CookieHelper.getRefreshTokenCookie(httpServletRequest)
        if (refreshCookie != null) {
            return try {
                authenticationService.refreshToken(httpServletRequest, httpServletResponse, refreshCookie)
            } catch (ex: HttpClientErrorException) {
                throw UnauthorizedClientException("could not refresh OAuth2 token", ex)
            }
        }

        if (accessTokenCookie != null) {
            log.warn("access token found, but no refresh token, stripping them all")
            //the token store may return null for a token value it does not know; treat that like an
            //unusable token instead of dereferencing null, so doFilter can strip the cookies
            val token = tokenStore.readAccessToken(accessTokenCookie.value)
                ?: throw InvalidTokenException("access token is unknown to the token store, and there's no refresh token")
            if (token.isExpired) {
                throw InvalidTokenException("access token has expired, but there's no refresh token")
            }
        }
        //no cookies at all, or an access token that is still valid for a short while: pass through unchanged
        return httpServletRequest
    }

    /**
     * Check if we must refresh the access token.
     * We must refresh it, if we either have no access token, or the token store does not know it,
     * or it is expired, or it is about to expire.
     *
     * @param accessTokenCookie the current access token cookie, or null if the request has none.
     * @return true, if it must be refreshed; false, otherwise.
     */
    private fun mustRefreshToken(accessTokenCookie: Cookie?): Boolean {
        if (accessTokenCookie == null) {
            return true
        }
        //a null result means the store has no such token, so the cookie cannot be relied upon
        val token = tokenStore.readAccessToken(accessTokenCookie.value) ?: return true
        //check if token is expired or about to expire
        return token.isExpired || token.expiresIn < REFRESH_WINDOW_SECS
    }

    companion object {
        /**
         * Number of seconds before expiry to start refreshing access tokens so we don't run into race conditions when forwarding
         * requests downstream. Otherwise, access tokens may still be valid when we check here but may then be expired
         * when relayed to another microservice a wee bit later.
         */
        private const val REFRESH_WINDOW_SECS = 30
    }
}
